{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "800a5928-7446-42e8-94ee-28a25a353a31",
   "metadata": {},
   "source": [
    "!pip install unsloth\n",
    "\n",
    "!pip install levenshtein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cc3e684b-19f8-4a3d-be76-c44e2f234f4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-07 02:59:46,335] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/compiler_compat/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "/root/miniconda3/compiler_compat/ld: warning: librt.so.1, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/root/miniconda3/compiler_compat/ld: warning: libpthread.so.0, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/root/miniconda3/compiler_compat/ld: warning: libstdc++.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/root/miniconda3/compiler_compat/ld: warning: libm.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for bool@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_logic_error(char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::logic_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::~locale()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_end_catch@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::logic_error::~logic_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__si_class_type_info@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new[](unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak_hard()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::basic_streambuf(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned short@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::resize(unsigned long, char)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char const*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ctype<char>::_M_widen_init() const@GLIBCXX_3.4.11'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_invalid_argument(char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::operator=(std::locale const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_free_exception@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::notify_one()@GLIBCXX_3.4.11'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::~Init()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::~basic_string()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_pure_virtual@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::flush()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__class_type_info@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_rethrow@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_fstream<char, std::char_traits<char> >::~basic_fstream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::compare(char const*) const@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::system_clock::now()@GLIBCXX_3.4.19'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Hash_bytes(void const*, unsigned long, unsigned long)@CXXABI_1.3.5'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long long>(long long)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const@GLIBCXX_3.4.18'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::out_of_range@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long>(unsigned long)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::~ios_base()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::range_error::~range_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::~__basic_file()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_acquire@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<bool>(bool)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::overflow_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::range_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_filebuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete[](void*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(unsigned long, char, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_transfer(std::__detail::_List_node_base*, std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::replace(unsigned long, unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::exception@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_M_destroy(std::allocator<wchar_t> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream& std::istream::_M_extract<double>(double&)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::basic_ifstream(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(std::string const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new(unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned int@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::domain_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char, unsigned long) const@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::put(char)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for int@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_alloc()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_thread_atexit@CXXABI_1.3.7'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned int*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::~basic_ifstream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::Init()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::condition_variable()@GLIBCXX_3.4.11'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::basic_filebuf()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::domain_error::~domain_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cerr@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char const*, unsigned long, unsigned long) const@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::str() const@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::invalid_argument@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(std::string const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_ostringstream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_rebalance_for_erase(std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_hook(std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_unhook()@GLIBCXX_3.4.15'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::_M_sync(char*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<char, std::char_traits<char> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `log2f@GLIBC_2.2.5'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::exception::~exception()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_create(unsigned long, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::is_open() const@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_istringstream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::swap(std::string&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::basic_streambuf(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::init(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_cast()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::clear(std::_Ios_Iostate)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::operator=(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete(void*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(int)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_M_destroy(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<wchar_t, std::char_traits<wchar_t> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::runtime_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_insert_and_rebalance(bool, std::_Rb_tree_node_base*, std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_stringstream()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long>(long)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::get()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long long@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::operator<< <std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::out_of_range::~out_of_range()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::length_error::~length_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::__ostream_insert<char, std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*, long)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::invalid_argument::~invalid_argument()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::swap(std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cout@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long long>(unsigned long long)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<void const*>(void const*)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::underflow_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::out_of_range@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_allocate_exception@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void const*@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::init(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::reserve(unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_begin_catch@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::open(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_M_sync(wchar_t*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::getline(char*, long, char)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istream<char, std::char_traits<char> >& std::getline<char, std::char_traits<char>, std::allocator<char> >(std::basic_istream<char, std::char_traits<char> >&, std::basic_string<char, std::char_traits<char>, std::allocator<char> >&, char)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::~condition_variable()@GLIBCXX_3.4.11'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::insert(unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned char@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::ios_base()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_out_of_range(char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::overflow_error::~overflow_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_length_error(char const*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_system_error(int)@GLIBCXX_3.4.11'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ofstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<double>(double)@GLIBCXX_3.4.9'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::operator=(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long long@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(char const*, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_release@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_throw@CXXABI_1.3'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::underflow_error::~underflow_error()@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::length_error@GLIBCXX_3.4'\n",
      "/root/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::~basic_filebuf()@GLIBCXX_3.4'\n",
      "collect2: error: ld returned 1 exit status\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3e04a85d1cb44ed96ef9b9c4d2f14bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os \n",
    "os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n",
    "# os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:128\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2\"\n",
    "from datasets import load_dataset,Dataset\n",
    "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel\n",
    "from qwen_vl_utils import process_vision_info\n",
    "from trl import GRPOConfig, GRPOTrainer\n",
    "from transformers import (\n",
    "AutoModelForCausalLM, \n",
    "AutoTokenizer, \n",
    "Qwen2_5_VLForConditionalGeneration, \n",
    "AutoProcessor,\n",
    "BitsAndBytesConfig\n",
    ")\n",
    "import torch\n",
    "import deepspeed\n",
    "\n",
    " \n",
    "# compute_dtype = getattr(torch, \"float16\")\n",
    "\n",
    "# bnb_config = BitsAndBytesConfig(\n",
    "#     load_in_4bit=True,\n",
    "#     bnb_4bit_use_double_quant=False,\n",
    "#     bnb_4bit_quant_type=\"nf4\",\n",
    "#     bnb_4bit_compute_dtype=compute_dtype,\n",
    "# )\n",
    "\n",
    "device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\") or 0)}  \n",
    "\n",
    "tokenizer = AutoProcessor.from_pretrained(\"/root/autodl-tmp/Qwen/Qwen2.5-VL-7B-Instruct\",\n",
    "                                         use_fast=True)\n",
    "min_pixels = 256*28*28\n",
    "max_pixels = 1280*28*28\n",
    "processor = AutoProcessor.from_pretrained(\"/root/autodl-tmp/Qwen/Qwen2.5-VL-7B-Instruct\", min_pixels=min_pixels, max_pixels=max_pixels,use_fast=True)\n",
    "# use cuda device\n",
    "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\"/root/autodl-tmp/Qwen/Qwen2.5-VL-7B-Instruct\", \n",
    "                                             device_map=device_map, \n",
    "                                            torch_dtype=torch.bfloat16,\n",
    "                                            # quantization_config=bnb_config\n",
    "                                                          ).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71b2732d-7059-40aa-bf17-78da4b3af8b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'messages': [{'role': 'user', 'content': '图中的组织类型是什么？'}, {'role': 'assistant', 'content': '图中的组织切片显示了一个鼻内肿块的病理学样本，这是从鼻腔内部取样的。'}, {'role': 'user', 'content': '放大倍率是多少？'}, {'role': 'assistant', 'content': '图像的放大倍数为200倍。'}, {'role': 'user', 'content': '根据病理学特征，诊断是什么？'}, {'role': 'assistant', 'content': '诊断为B细胞淋巴瘤。B细胞淋巴瘤起源于B淋巴细胞，这是一种白血细胞。影像中显示，大量圆形大细胞弥漫性地遮盖了呼吸道上皮下的基底膜。结合其他临床和实验室资料，这一诊断得以确立。'}], 'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=720x494 at 0x7F5ED85F7C80>]}\n",
      "{'prompt': [{'content': [{'text': '\\n        Below is an instruction that describes a task, paired with an input that provides further context.\\n        Write a response that appropriately completes the request.\\n        Before answering, think carefully about the question and create a step-by-step chain of \\n        thoughts to ensure a logical and accurate response.\\n        \\n        ### Instruction:\\n        You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.\\n        Please answer the following medical question based on the input image. Output the thinking process in <think> </think> and final answer in <answer> </answer> tags.The output format should be as follows:\\n<think> ... </think> <answer>...</answer>\\n除了特殊符号，请用中文回答\\n        ', 'type': 'text'}], 'role': 'system'}, {'content': [{'text': None, 'type': 'image'}, {'text': '胸腔穿刺的目的是什么？', 'type': 'text'}], 'role': 'user'}], 'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=140x140 at 0x7F5ED84D8710>, 'solution': '胸腔穿刺是一种医疗操作，用于从胸膜腔（即肺和胸壁之间的区域）抽出过多的积液。其主要目的是缓解由积液导致的呼吸困难和胸痛症状，诊断胸腔积液的原因，并治疗如感染或恶性肿瘤等特定情况。通过排出多余的液体，胸腔穿刺能提升肺功能，舒缓症状，以便进行进一步的诊断和治疗。'}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from datasets import load_dataset,Dataset\n",
    "from PIL import Image\n",
    "import base64\n",
    "from io import BytesIO\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    " \n",
    "ds = load_dataset(\"BUAADreamer/llava-med-zh-instruct-60k\",split = \"train[0:2000]\", trust_remote_code=True,cache_dir='./data')\n",
    "print(ds[0])# show content\n",
    " \n",
    " \n",
    "def get_prompt_rft(example):\n",
    "    '''\n",
    "    input: dict example, including PIL image object\n",
    "    output: multiple samples, within a dict format\n",
    "    '''\n",
    "    dialogue_num = len(example['messages'])\n",
    "    i = 0\n",
    "    results=[]\n",
    "    while i<dialogue_num:\n",
    "        assert example['messages'][i]['role']=='user' and example['messages'][i+1]['role']=='assistant'\n",
    "        question_sample = example['messages'][i]['content']\n",
    "        answer_sample = example['messages'][i+1]['content']\n",
    "        img_pil = example['images'][0].resize((140,140))  # reduce vRAM burden\n",
    "        out_results = []\n",
    "        SYSTEM_PROMPT = r'''\n",
    "        Below is an instruction that describes a task, paired with an input that provides further context.\n",
    "        Write a response that appropriately completes the request.\n",
    "        Before answering, think carefully about the question and create a step-by-step chain of \n",
    "        thoughts to ensure a logical and accurate response.\n",
    "        \n",
    "        ### Instruction:\n",
    "        You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.\n",
    "        Please answer the following medical question based on the input image. Output the thinking process in <think> </think> and final answer in <answer> </answer> tags.The output format should be as follows:\n",
    "<think> ... </think> <answer>...</answer>\n",
    "除了特殊符号，请用中文回答\n",
    "        '''   # for a different language, please change the last few words.\n",
    "        results.append({\n",
    "                'prompt': [\n",
    "                    {'role': 'system', 'content': [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}]},\n",
    "                    {'role': 'user', 'content': [\n",
    "                        {\"type\": \"image\", },  \n",
    "                        {\"type\": \"text\", \"text\": question_sample},    \n",
    "                    ]}\n",
    "                ],\n",
    "                'image':img_pil,\n",
    "                'solution':answer_sample,\n",
    "            })\n",
    "        i+=2\n",
    "    return results\n",
    " \n",
    "def dataset_gen():\n",
    "    for items in ds:\n",
    "        multiple_out = get_prompt_rft(items)\n",
    "        for single_out in multiple_out:\n",
    "            yield single_out\n",
    "my_gen = dataset_gen()\n",
    "\n",
    "\n",
    " \n",
    "dataset_train = Dataset.from_generator(dataset_gen)\n",
    "print(dataset_train[-1])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee74d79a-eb96-4dc1-9776-c9527c907f03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['这张图片显示的是一个组织切片的显微镜图像，染色为HE（苏木精-伊红染色）。从图像上看，细胞核呈现蓝色，细胞质呈现粉红色。这种染色方式通常用于病理学检查，以观察细胞结构和形态。\\n\\n根据图像特征，这可能是淋巴结或其他器官的组织切片。在这样的图像中，可以看到大量的淋巴样细胞，这些细胞通常是淋巴细胞或浆细胞。然而，仅凭这张图像无法做出确切的诊断。病理学家需要结合临床信息、病史和其他实验室检查结果来做出最终诊断。\\n\\n如果这是来自']\n"
     ]
    }
   ],
   "source": [
    "from qwen_vl_utils import process_vision_info\n",
    "image = ds[0]['images'][0]\n",
    "instruction = \"You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \\\n",
    "Please answer the following medical question based on the input image. 请用中文回答\"\n",
    "# for a different language, please change the last few words.\n",
    "# messages = [\n",
    "#     {\"role\": \"user\", \"content\": [\n",
    "#         {\"type\": \"image\"},\n",
    "#         {\"type\": \"text\", \"text\": instruction}\n",
    "#     ]}\n",
    "# ]\n",
    "\n",
    "\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"image\",\n",
    "                \"image\": image,\n",
    "            },\n",
    "            {\n",
    "                \"type\": \"text\",\n",
    "                \"text\": instruction,\n",
    "            },\n",
    "        ],\n",
    "    }\n",
    "    \n",
    "]\n",
    "\n",
    "text = processor.apply_chat_template(\n",
    "    messages, tokenize=False, add_generation_prompt=True\n",
    ")\n",
    "\n",
    "image_inputs, video_inputs = process_vision_info(messages)\n",
    "inputs = processor(\n",
    "    text=[text],\n",
    "    images=image_inputs,\n",
    "    videos=video_inputs,\n",
    "    padding=True,\n",
    "    return_tensors=\"pt\",\n",
    "    \n",
    ")\n",
    "device = next(model.parameters()).device\n",
    "    # 将所有张量移动到指定的设备上\n",
    "for key, value in inputs.items():\n",
    "    inputs[key] = value.to(device)\n",
    "\n",
    "# Inference: Generation of the output\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=128)\n",
    "generated_ids_trimmed = [\n",
    "    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
    "]\n",
    "output_text = processor.batch_decode(\n",
    "    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    ")\n",
    "print(output_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "47f07208-7ac8-46c0-a34c-382483011284",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "from Levenshtein import ratio as levenshtein_ratio\n",
    "def format_reward_func(completions, **kwargs):\n",
    "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
    "    # print(completions) #debug\n",
    "    pattern = r\"^<think>.*?</think>.*?<answer>.*?</answer>$\"\n",
    "    matches = [re.match(pattern, content[0]['content'], re.DOTALL) for content in completions]\n",
    "    return [1.0 if match else 0.0 for match in matches]\n",
    "def levenshtein_reward_func(completions, solution, **kwargs):\n",
    "    \"\"\"Reward function that checks if the completion get solutions correctly.\"\"\"\n",
    "    res = []\n",
    "    for completion, sol in zip(completions, solution):\n",
    "        completion = completion[0]['content']\n",
    "        if '</think>' in completion:\n",
    "            t = completion.split('</think>')[-1]    # calculate result distance\n",
    "            res.append(levenshtein_ratio(t, sol))\n",
    "        else:\n",
    "            res.append(0.0)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "987760f0-702a-4303-87cd-f791e440bbba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前工作目录: /root/autodl-tmp\n",
      "文件是否存在: True\n",
      "目录内容: ['.autodl', '._____temp', '.ipynb_checkpoints', 'Qwen', 'download.py', 'Qwen2.5VL-GRPO.ipynb', 'data', 'vlm_rft_trainer', 'outputs', 'train.py', 'ds_z2_offload_config.json', 'ds_z3_offload_config.json']\n"
     ]
    }
   ],
   "source": [
    "import subprocess\n",
    "\n",
    "\n",
    "# result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
    "# output = result.stdout\n",
    "# for line in output.splitlines():\n",
    "#     if '=' in line:\n",
    "#         var, value = line.split('=', 1)\n",
    "#         os.environ[var] = value\n",
    "output_dir=\"./outputs/Qwevl-Instruct-GRPO\"\n",
    "run_name=\"Qwen-vl-GRPO-medical\"\n",
    "\n",
    "from trl import GRPOConfig\n",
    "# !git clone https://github.com/auto-Dog/vlm_rft_trainer.git\n",
    "# !cd /kaggle/working/vlm_rft_trainer && git pull    # 替换成自己的地址\n",
    "# !cp /kaggle/working/vlm_rft_trainer/grpo_trainer.py grpo_trainer.py    # 替换成自己的地址\n",
    "\n",
    "# 打印当前工作目录\n",
    "print(\"当前工作目录:\", os.getcwd())\n",
    "\n",
    "# 检查文件是否存在\n",
    "file_path = \"ds_z2_offload_config.json\"\n",
    "print(\"文件是否存在:\", os.path.exists(file_path))\n",
    "\n",
    "# 列出当前目录下的文件\n",
    "print(\"目录内容:\", os.listdir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d40e2e3e-cd3a-4ff7-8277-7a4cbd5e3725",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "配置文件路径: /root/autodl-tmp/ds_z2_offload_config.json\n",
      "文件是否存在: True\n"
     ]
    }
   ],
   "source": [
    "# 获取 Notebook 所在目录\n",
    "notebook_dir = os.path.dirname(os.path.abspath(\".\"))  # 或者用 os.getcwd()\n",
    "ds_config_path = os.path.join(notebook_dir, \"/root/autodl-tmp/ds_z2_offload_config.json\")\n",
    "\n",
    "# 检查路径是否正确\n",
    "print(\"配置文件路径:\", ds_config_path)\n",
    "print(\"文件是否存在:\", os.path.exists(ds_config_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b160b498-b8de-47cc-a4c0-556128454b1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.12/site-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
      "  warnings.warn(\n",
      "No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "Caught OutOfMemoryError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py\", line 96, in _worker\n    output = module(*input, **kwargs)\n             ^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py\", line 818, in forward\n    return self.get_base_model()(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1825, in forward\n    outputs = self.model(\n              ^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1171, in forward\n    layer_outputs = decoder_layer(\n                    ^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1048, in forward\n    hidden_states = self.mlp(hidden_states)\n                    ^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 603, in forward\n    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n                                           ^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/peft/tuners/lora/layer.py\", line 727, in forward\n    result = result + lora_B(lora_A(dropout(x))) * scaling\n                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~\ntorch.OutOfMemoryError: CUDA out of memory. Tried to allocate 170.00 MiB. GPU 0 has a total capacity of 31.60 GiB of which 46.12 MiB is free. Process 174057 has 31.55 GiB memory in use. Of the allocated memory 30.95 GiB is allocated by PyTorch, and 206.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 55\u001b[0m\n\u001b[1;32m     19\u001b[0m training_args \u001b[38;5;241m=\u001b[39m GRPOConfig(\n\u001b[1;32m     20\u001b[0m     \u001b[38;5;66;03m# use_vllm = True, # use vLLM for fast inference!\u001b[39;00m\n\u001b[1;32m     21\u001b[0m     learning_rate \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m5e-6\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     42\u001b[0m     deepspeed\u001b[38;5;241m=\u001b[39mds_config_path\n\u001b[1;32m     43\u001b[0m )\n\u001b[1;32m     44\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Qwen2VLGRPOTrainer(\n\u001b[1;32m     45\u001b[0m     model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m     46\u001b[0m     processing_class\u001b[38;5;241m=\u001b[39mtokenizer,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     52\u001b[0m     peft_config \u001b[38;5;241m=\u001b[39m peft_config,\n\u001b[1;32m     53\u001b[0m )\n\u001b[0;32m---> 55\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     57\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(output_dir)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2245\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   2243\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   2244\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2245\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2246\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2247\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2248\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2249\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2250\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2560\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2553\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   2554\u001b[0m     functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m   2555\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   2556\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m   2557\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m   2558\u001b[0m )\n\u001b[1;32m   2559\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2560\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2562\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   2563\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   2564\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m   2565\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   2566\u001b[0m ):\n\u001b[1;32m   2567\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   2568\u001b[0m     tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3736\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m   3733\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m   3735\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3736\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3738\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m inputs\n\u001b[1;32m   3739\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   3740\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   3741\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m   3742\u001b[0m ):\n",
      "File \u001b[0;32m~/autodl-tmp/vlm_rft_trainer/grpo_trainer.py:427\u001b[0m, in \u001b[0;36mQwen2VLGRPOTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs, num_items_in_batch)\u001b[0m\n\u001b[1;32m    424\u001b[0m pixel_values \u001b[38;5;241m=\u001b[39m prompt_inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mrepeat_interleave(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_generations, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, pixel_values\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])  \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m prompt_inputs \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    425\u001b[0m image_grid_thw \u001b[38;5;241m=\u001b[39m prompt_inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_grid_thw\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mrepeat_interleave(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_generations, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_grid_thw\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m prompt_inputs \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 427\u001b[0m per_token_logps \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_per_token_logps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt_completion_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage_grid_thw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    428\u001b[0m \u001b[38;5;66;03m# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)\u001b[39;00m\n\u001b[1;32m    429\u001b[0m per_token_logps \u001b[38;5;241m=\u001b[39m per_token_logps[:, prompt_length \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m :]\n",
      "File \u001b[0;32m~/autodl-tmp/vlm_rft_trainer/grpo_trainer.py:358\u001b[0m, in \u001b[0;36mQwen2VLGRPOTrainer._get_per_token_logps\u001b[0;34m(self, model, input_ids, attention_mask, pixel_values, image_grid_thw)\u001b[0m\n\u001b[1;32m    357\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_per_token_logps\u001b[39m(\u001b[38;5;28mself\u001b[39m, model, input_ids, attention_mask, pixel_values, image_grid_thw):\n\u001b[0;32m--> 358\u001b[0m     logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage_grid_thw\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimage_grid_thw\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits  \u001b[38;5;66;03m# (B, L, V)\u001b[39;00m\n\u001b[1;32m    359\u001b[0m     logits \u001b[38;5;241m=\u001b[39m logits[:, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]  \u001b[38;5;66;03m# (B, L-1, V), exclude the last logit: it corresponds to the next token pred\u001b[39;00m\n\u001b[1;32m    360\u001b[0m     input_ids \u001b[38;5;241m=\u001b[39m input_ids[:, \u001b[38;5;241m1\u001b[39m:]  \u001b[38;5;66;03m# (B, L-1), exclude the first input ID since we don't have logits for it\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1739\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1737\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1738\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1739\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1750\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1745\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1746\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1747\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1748\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1749\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1750\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1752\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1753\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py:193\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m    191\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m    192\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[: \u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 193\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py:212\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\n\u001b[1;32m    210\u001b[0m     \u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any\n\u001b[1;32m    211\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 212\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    213\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py:126\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m    124\u001b[0m     output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m    125\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m--> 126\u001b[0m         \u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    127\u001b[0m     outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/_utils.py:733\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    729\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m    730\u001b[0m     \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m    731\u001b[0m     \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m    732\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: Caught OutOfMemoryError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py\", line 96, in _worker\n    output = module(*input, **kwargs)\n             ^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py\", line 818, in forward\n    return self.get_base_model()(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1825, in forward\n    outputs = self.model(\n              ^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1171, in forward\n    layer_outputs = decoder_layer(\n                    ^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 1048, in forward\n    hidden_states = self.mlp(hidden_states)\n                    ^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\", line 603, in forward\n    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n                                           ^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1739, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py\", line 1750, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File \"/root/miniconda3/lib/python3.12/site-packages/peft/tuners/lora/layer.py\", line 727, in forward\n    result = result + lora_B(lora_A(dropout(x))) * scaling\n                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~\ntorch.OutOfMemoryError: CUDA out of memory. Tried to allocate 170.00 MiB. GPU 0 has a total capacity of 31.60 GiB of which 46.12 MiB is free. Process 174057 has 31.55 GiB memory in use. Of the allocated memory 30.95 GiB is allocated by PyTorch, and 206.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)\n"
     ]
    }
   ],
   "source": [
    "from vlm_rft_trainer.grpo_trainer import Qwen2VLGRPOTrainer # third-party trainer from open-R1\n",
    "# from grpo_trainer import Qwen2VLGRPOTrainer # third-party trainer from open-R1\n",
    "model.train()\n",
    "peft_config = LoraConfig(\n",
    "    r=32, #Rank\n",
    "    lora_alpha=16,\n",
    "    target_modules=[\n",
    "        \"q_proj\", \n",
    "        \"k_proj\", \n",
    "        \"v_proj\", \n",
    "        \"o_proj\", \n",
    "        \"gate_proj\", \n",
    "        \"up_proj\", \n",
    "        \"down_proj\"\n",
    "    ],\n",
    "    bias=\"none\",\n",
    "    lora_dropout=0.05,  # Conventional\n",
    ")\n",
    " \n",
    "training_args = GRPOConfig(\n",
    "    # use_vllm = True, # use vLLM for fast inference!\n",
    "    learning_rate = 5e-6,\n",
    "    adam_beta1 = 0.9,\n",
    "    adam_beta2 = 0.99,\n",
    "    weight_decay = 0.1,\n",
    "    warmup_ratio = 0.1,\n",
    "    lr_scheduler_type = \"cosine\",\n",
    "    optim = \"adamw_8bit\",\n",
    "    logging_steps = 1,\n",
    "    bf16 = True,\n",
    "    fp16 = False,\n",
    "    per_device_train_batch_size = 1,# keep same with num_generations\n",
    "    gradient_accumulation_steps = 16, # Increase to 4 for smoother training\n",
    "    num_generations = 4, # Decrease if out of memory\n",
    "    max_prompt_length = 2048,\n",
    "    max_completion_length = 2048,\n",
    "    num_train_epochs = 1, # Set to 1 for a full training run\n",
    "    max_steps = 100,\n",
    "    save_steps = 5,\n",
    "    max_grad_norm = 0.1,\n",
    "    report_to = \"none\", # Can use Weights & Biases\n",
    "    output_dir = \"outputs\",\n",
    "    deepspeed=ds_config_path\n",
    ")\n",
    "trainer = Qwen2VLGRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        format_reward_func, # all reward functions\n",
    "        levenshtein_reward_func],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset_train.select(range(2000)),\n",
    "    peft_config = peft_config,\n",
    ")\n",
    " \n",
    "trainer.train()\n",
    " \n",
    "trainer.save_model(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78eb1227-e29d-468e-b07f-03b2ba05d0fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "message = dataset_train[0]['prompt']\n",
    "image = dataset_train[0]['image']\n",
    "input_text = tokenizer.apply_chat_template(message, add_generation_prompt = True)\n",
    "inputs = tokenizer(\n",
    "    image,\n",
    "    input_text,\n",
    "    add_special_tokens = False,\n",
    "    return_tensors = \"pt\",\n",
    ").to(\"cuda\")\n",
    " \n",
    "from transformers import TextStreamer\n",
    "text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n",
    "_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512,\n",
    "                   use_cache = True, temperature = 1.5, min_p = 0.1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
